{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from tqdm import tqdm\n",
    "from annoy import AnnoyIndex\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PreTrainedEmbeddings(object):\n",
    "    \"\"\" A wrapper around pre-trained word vectors and an Annoy index built over them \"\"\"\n",
    "    def __init__(self, word_to_index, word_vectors):\n",
    "        \"\"\"Store the vocabulary and build a euclidean Annoy index over the vectors.\n",
    "\n",
    "        Args:\n",
    "            word_to_index (dict): mapping from word to integers; each integer\n",
    "                is used as a position in word_vectors\n",
    "            word_vectors (list of numpy arrays): must be non-empty; the length\n",
    "                of the first vector is used as the index dimensionality\n",
    "        \"\"\"\n",
    "        self.word_to_index = word_to_index\n",
    "        self.word_vectors = word_vectors\n",
    "        self.index_to_word = {v: k for k, v in self.word_to_index.items()}\n",
    "\n",
    "        self.index = AnnoyIndex(len(word_vectors[0]), metric='euclidean')\n",
    "        print(\"Building Index!\")\n",
    "        # Only the integer ids are needed here, so iterate over the values.\n",
    "        for i in self.word_to_index.values():\n",
    "            self.index.add_item(i, self.word_vectors[i])\n",
    "        # The argument to build() is the number of trees in the Annoy forest.\n",
    "        self.index.build(50)\n",
    "        print(\"Finished!\")\n",
    "\n",
    "    @classmethod\n",
    "    def from_embeddings_file(cls, embedding_file):\n",
    "        \"\"\"Instantiate from pre-trained vector file.\n",
    "\n",
    "        Vector file should be of the format:\n",
    "            word0 x0_0 x0_1 x0_2 x0_3 ... x0_N\n",
    "            word1 x1_0 x1_1 x1_2 x1_3 ... x1_N\n",
    "\n",
    "        The file is decoded as UTF-8 and read one line at a time. Blank lines\n",
    "        are skipped. If a word occurs more than once, only its first vector is\n",
    "        kept so that word_to_index and word_vectors stay aligned.\n",
    "\n",
    "        Args:\n",
    "            embedding_file (str): location of the file\n",
    "        Returns:\n",
    "            instance of PreTrainedEmbeddings\n",
    "        \"\"\"\n",
    "        word_to_index = {}\n",
    "        word_vectors = []\n",
    "\n",
    "        # Explicit encoding: the platform default encoding is not UTF-8 everywhere.\n",
    "        with open(embedding_file, encoding='utf-8') as fp:\n",
    "            # Iterate over the handle instead of readlines() so the whole file\n",
    "            # is never held in memory as a list of strings.\n",
    "            for line in fp:\n",
    "                # Drop the newline and any trailing whitespace before splitting.\n",
    "                line = line.rstrip()\n",
    "                if not line:\n",
    "                    continue\n",
    "                fields = line.split(\" \")\n",
    "                word = fields[0]\n",
    "                if word in word_to_index:\n",
    "                    # A repeated word would shift every later index; keep the first.\n",
    "                    continue\n",
    "                vec = np.array([float(x) for x in fields[1:]])\n",
    "\n",
    "                # The index is the position the vector is about to occupy.\n",
    "                word_to_index[word] = len(word_vectors)\n",
    "                word_vectors.append(vec)\n",
    "\n",
    "        return cls(word_to_index, word_vectors)\n",
    "\n",
    "    def get_embedding(self, word):\n",
    "        \"\"\"Look up the stored vector for a word.\n",
    "\n",
    "        Args:\n",
    "            word (str)\n",
    "        Returns\n",
    "            an embedding (numpy.ndarray)\n",
    "        Raises:\n",
    "            KeyError: if word is not in word_to_index\n",
    "        \"\"\"\n",
    "        return self.word_vectors[self.word_to_index[word]]\n",
    "\n",
    "    def get_closest_to_vector(self, vector, n=1):\n",
    "        \"\"\"Given a vector, return its n nearest neighbors\n",
    "\n",
    "        Args:\n",
    "            vector (np.ndarray): should match the size of the vectors\n",
    "                in the Annoy index\n",
    "            n (int): the number of neighbors to return\n",
    "        Returns:\n",
    "            [str, str, ...]: words that are nearest to the given vector,\n",
    "                in the order the Annoy index returns them\n",
    "        \"\"\"\n",
    "        nn_indices = self.index.get_nns_by_vector(vector, n)\n",
    "        return [self.index_to_word[neighbor] for neighbor in nn_indices]\n",
    "\n",
    "    def compute_and_print_analogy(self, word1, word2, word3):\n",
    "        \"\"\"Prints the solutions to analogies using word embeddings\n",
    "\n",
    "        Analogies are word1 is to word2 as word3 is to __\n",
    "        This method will print: word1 : word2 :: word3 : word4\n",
    "\n",
    "        Args:\n",
    "            word1 (str)\n",
    "            word2 (str)\n",
    "            word3 (str)\n",
    "        Raises:\n",
    "            KeyError: if any of the three words is not in word_to_index\n",
    "        \"\"\"\n",
    "        vec1 = self.get_embedding(word1)\n",
    "        vec2 = self.get_embedding(word2)\n",
    "        vec3 = self.get_embedding(word3)\n",
    "\n",
    "        # now compute the fourth word's embedding!\n",
    "        spatial_relationship = vec2 - vec1\n",
    "        vec4 = vec3 + spatial_relationship\n",
    "\n",
    "        # Ask for four neighbors: up to three of them can be the input words,\n",
    "        # which are filtered out below.\n",
    "        closest_words = self.get_closest_to_vector(vec4, n=4)\n",
    "        existing_words = {word1, word2, word3}\n",
    "        closest_words = [word for word in closest_words\n",
    "                         if word not in existing_words]\n",
    "\n",
    "        if len(closest_words) == 0:\n",
    "            print(\"Could not find nearest neighbors for the computed vector!\")\n",
    "            return\n",
    "\n",
    "        for word4 in closest_words:\n",
    "            print(\"{} : {} :: {} : {}\".format(word1, word2, word3, word4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Building Index!\n",
      "Finished!\n"
     ]
    }
   ],
   "source": [
    "# Parse the vector file and build the nearest-neighbor index over it.\n",
    "embeddings = PreTrainedEmbeddings.from_embeddings_file(\n",
    "    embedding_file='data/glove/glove.6B.100d.txt'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "man : he :: woman : she\n",
      "man : he :: woman : never\n",
      "man : he :: woman : her\n"
     ]
    }
   ],
   "source": [
    "# Analogy query -- man : he :: woman : ?\n",
    "embeddings.compute_and_print_analogy(word1='man', word2='he', word3='woman')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fly : plane :: sail : ship\n",
      "fly : plane :: sail : vessel\n"
     ]
    }
   ],
   "source": [
    "# Analogy query -- fly : plane :: sail : ?\n",
    "embeddings.compute_and_print_analogy(word1='fly', word2='plane', word3='sail')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cat : kitten :: dog : puppy\n",
      "cat : kitten :: dog : puppies\n",
      "cat : kitten :: dog : toddler\n"
     ]
    }
   ],
   "source": [
    "# Analogy query -- cat : kitten :: dog : ?\n",
    "embeddings.compute_and_print_analogy(word1='cat', word2='kitten', word3='dog')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "blue : color :: dog : pig\n",
      "blue : color :: dog : typical\n",
      "blue : color :: dog : adult\n",
      "blue : color :: dog : and/or\n"
     ]
    }
   ],
   "source": [
    "# Analogy query -- blue : color :: dog : ?\n",
    "embeddings.compute_and_print_analogy(word1='blue', word2='color', word3='dog')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "leg : legs :: hand : hands\n",
      "leg : legs :: hand : ears\n",
      "leg : legs :: hand : eyes\n"
     ]
    }
   ],
   "source": [
    "# Analogy query -- leg : legs :: hand : ?\n",
    "embeddings.compute_and_print_analogy(word1='leg', word2='legs', word3='hand')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "toe : foot :: finger : hand\n",
      "toe : foot :: finger : kept\n",
      "toe : foot :: finger : ground\n"
     ]
    }
   ],
   "source": [
    "# Analogy query -- toe : foot :: finger : ?\n",
    "embeddings.compute_and_print_analogy(word1='toe', word2='foot', word3='finger')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "talk : communicate :: read : correctly\n",
      "talk : communicate :: read : instructions\n"
     ]
    }
   ],
   "source": [
    "# Analogy query -- talk : communicate :: read : ?\n",
    "embeddings.compute_and_print_analogy(word1='talk', word2='communicate', word3='read')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "blue : democrat :: red : republican\n",
      "blue : democrat :: red : congressman\n",
      "blue : democrat :: red : senator\n"
     ]
    }
   ],
   "source": [
    "# Analogy query -- blue : democrat :: red : ?\n",
    "embeddings.compute_and_print_analogy(word1='blue', word2='democrat', word3='red')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "man : king :: woman : queen\n",
      "man : king :: woman : throne\n",
      "man : king :: woman : elizabeth\n"
     ]
    }
   ],
   "source": [
    "# Analogy query -- man : king :: woman : ?\n",
    "embeddings.compute_and_print_analogy(word1='man', word2='king', word3='woman')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "man : doctor :: woman : nurse\n",
      "man : doctor :: woman : physician\n",
      "man : doctor :: woman : pregnant\n"
     ]
    }
   ],
   "source": [
    "# Analogy query -- man : doctor :: woman : ?\n",
    "embeddings.compute_and_print_analogy(word1='man', word2='doctor', word3='woman')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fast : fastest :: small : ten\n",
      "fast : fastest :: small : registered\n",
      "fast : fastest :: small : eight\n"
     ]
    }
   ],
   "source": [
    "# Analogy query -- fast : fastest :: small : ?\n",
    "embeddings.compute_and_print_analogy(word1='fast', word2='fastest', word3='small')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "nlpbook",
   "language": "python",
   "name": "nlpbook"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
